{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MNIST Image Classification with a CNN Model\n",
    "\n",
    "### Build the dataset and dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "\n",
    "# Seed PyTorch's global RNG so that the shuffling order of the training loader\n",
    "# (and any later PyTorch random draws) repeat under Restart Kernel -> Run All\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Preprocessing: ToTensor converts each image to a float tensor scaled to [0, 1];\n",
    "# Normalize then applies (x - 0.5) / 0.5 to the single channel, giving [-1, 1]\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5,), (0.5,))\n",
    "])\n",
    "\n",
    "# Load the MNIST train and test splits (downloaded into ./data on first use)\n",
    "train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
    "test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "# Create the data loaders: only the training data is shuffled, the test data is\n",
    "# read in its stored order\n",
    "batch_size = 64\n",
    "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize samples from the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Pick 16 distinct random positions in the training set\n",
    "indices = np.random.choice(len(train_dataset), size=16, replace=False)\n",
    "\n",
    "# Look each sample up exactly once: every dataset access re-runs the transform\n",
    "# pipeline, so indexing separately for the image and for the label doubles the work\n",
    "samples = [train_dataset[int(i)] for i in indices]\n",
    "images = [sample[0] for sample in samples]\n",
    "labels = [sample[1] for sample in samples]\n",
    "\n",
    "# Create a figure with 4x4 subplots\n",
    "fig, axes = plt.subplots(4, 4, figsize=(10, 10))\n",
    "\n",
    "# Show one digit per subplot, titled with its ground-truth label\n",
    "for i, ax in enumerate(axes.flat):\n",
    "    # squeeze() drops the channel dimension so imshow receives a 2-D array\n",
    "    ax.imshow(images[i].squeeze(), cmap='gray')\n",
    "    ax.set_title(f\"Label: {labels[i]}\")\n",
    "    ax.axis('off')\n",
    "\n",
    "# Adjust the spacing between subplots\n",
    "plt.tight_layout()\n",
    "\n",
    "# Show the figure\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build the CNN Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Small CNN: two conv/ReLU/max-pool stages followed by a two-layer fully connected head\n",
    "class CNN(nn.Module):\n",
    "    \"\"\"Convolutional classifier for single-channel images.\n",
    "\n",
    "    The first linear layer expects 800 flattened features, which is what the\n",
    "    feature extractor yields for a 28x28 input: 28 -> 26 -> 13 -> 11 -> 5 along\n",
    "    each spatial axis, with 32 channels, i.e. 32 * 5 * 5.\n",
    "\n",
    "    Args:\n",
    "        num_classes: Number of logits produced per input image.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes=10):\n",
    "        super().__init__()\n",
    "        conv_stages = [\n",
    "            nn.Conv2d(1, 16, kernel_size=3),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2),\n",
    "            nn.Conv2d(16, 32, kernel_size=3),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2),\n",
    "        ]\n",
    "        self.features = nn.Sequential(*conv_stages)\n",
    "\n",
    "        head_layers = [\n",
    "            nn.Linear(800, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, num_classes),\n",
    "        ]\n",
    "        self.classifier = nn.Sequential(*head_layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the class logits for a batch of images.\"\"\"\n",
    "        feature_maps = self.features(x)\n",
    "        # Keep the batch dimension and collapse everything else into one vector\n",
    "        flat = torch.flatten(feature_maps, 1)\n",
    "        return self.classifier(flat)\n",
    "\n",
    "\n",
    "# Instantiate the model and push one random MNIST-sized batch through it as a shape check\n",
    "model = CNN()\n",
    "x = torch.randn(64, 1, 28, 28)\n",
    "print(model)\n",
    "logits = model(x)\n",
    "print(logits.shape)\n",
    "\n",
    "# Count every parameter element; the ',' format spec inserts thousands separators\n",
    "num_params = sum(param.numel() for param in model.parameters())\n",
    "print(f\"Number of parameters: {num_params:,}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the loss function and optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when PyTorch can see one, otherwise stay on the CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "model.to(device)\n",
    "\n",
    "# Cross-entropy loss on the raw logits, optimised with Adam at a learning rate of 1e-3\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "num_epochs = 10\n",
    "\n",
    "# Put the model in training mode before optimising\n",
    "model.train()\n",
    "for epoch in tqdm(range(num_epochs)):\n",
    "    # Track the summed per-sample loss and the number of samples seen, so the epoch\n",
    "    # average stays exact even when the final batch is smaller than the others\n",
    "    total_loss = 0.0\n",
    "    num_samples = 0\n",
    "    \n",
    "    for images, labels in train_loader:\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "        \n",
    "        # Forward pass\n",
    "        outputs = model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "        \n",
    "        # Backward and optimize; gradients are cleared first because they accumulate\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        # criterion returns the mean over the batch, so scale it back up by the\n",
    "        # batch size before accumulating\n",
    "        batch_len = labels.size(0)\n",
    "        total_loss += loss.item() * batch_len\n",
    "        num_samples += batch_len\n",
    "    \n",
    "    avg_loss = total_loss / num_samples\n",
    "    print(f\"Epoch: {epoch+1}, Loss: {avg_loss:.4f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch the model to evaluation mode before scoring the test split\n",
    "model.eval()\n",
    "\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "# Gradients are not needed for scoring, so turn autograd tracking off\n",
    "with torch.no_grad():\n",
    "    for images, labels in test_loader:\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "\n",
    "        # The predicted class is the position of the largest logit in each row\n",
    "        outputs = model(images)\n",
    "        predicted = outputs.argmax(dim=1)\n",
    "\n",
    "        total += labels.size(0)\n",
    "        correct += predicted.eq(labels).sum().item()\n",
    "\n",
    "# Fraction of test samples whose predicted class matches the label\n",
    "accuracy = correct / total\n",
    "print(f\"Test Accuracy: {accuracy:.4f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predict on sample images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# Set the model to evaluation mode\n",
    "model.eval()\n",
    "\n",
    "# Select a random image from the test dataset (randint includes both endpoints)\n",
    "random_index = random.randint(0, len(test_dataset) - 1)\n",
    "image, label = test_dataset[random_index]\n",
    "\n",
    "# Move the image to the device\n",
    "image = image.to(device)\n",
    "\n",
    "# Predict without tracking gradients, as in the evaluation cell; unsqueeze(0) adds\n",
    "# the batch dimension the model expects\n",
    "with torch.no_grad():\n",
    "    output = model(image.unsqueeze(0))\n",
    "predicted_label = output.argmax(dim=1)\n",
    "\n",
    "# Convert the image tensor to a numpy array (must be on the CPU first)\n",
    "image_np = image.cpu().numpy()\n",
    "\n",
    "# Display the image, its label, and the predicted label\n",
    "plt.imshow(image_np.squeeze(), cmap='gray')\n",
    "plt.title(f\"Label: {label}, Predicted: {predicted_label.item()}\")\n",
    "plt.axis('off')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fuyu",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
